🔁 🦈 Support iterative GRPO #2700
Conversation
Co-authored-by: Quentin Gallouédec <[email protected]>
Nice! Can you try locally with multi-GPU / DeepSpeed ZeRO 1/2/3? If you don't have the hardware, I can do it.
In the DeepSeek-R1 paper, I think they sync the ref after each epoch, no?
@qgallouedec for the update, I thiiiink they do the update after one complete iteration (epoch), but I am not sure, because I think this way there might be a conflict with the default …. Maybe I am misunderstanding?
Note that this algorithm and the ref-model update discussion are from the DeepSeekMath paper, where they discuss the GRPO math. But the question still remains! 🤔
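For context, the term where the reference model enters the GRPO objective (as I recall it from the DeepSeekMath paper; worth double-checking against the paper) is the per-token KL estimator:

$$
\mathbb{D}_{\mathrm{KL}}\left[\pi_\theta \,\|\, \pi_{\mathrm{ref}}\right]
= \frac{\pi_{\mathrm{ref}}(o_{i,t} \mid q, o_{i,<t})}{\pi_{\theta}(o_{i,t} \mid q, o_{i,<t})}
- \log \frac{\pi_{\mathrm{ref}}(o_{i,t} \mid q, o_{i,<t})}{\pi_{\theta}(o_{i,t} \mid q, o_{i,<t})}
- 1,
$$

so how often $\pi_{\mathrm{ref}}$ is reset to the current policy directly controls how strong this regularizer is.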
Don't bother with multi-GPU, I'm going to test it myself. I think we understand it similarly. I'm wondering what the user would expect. Let me run some tests. I'll come back to you.
@qgallouedec Did you get to test this by any chance? 🤔
Not yet, will do asap
Actually I don't have time to test, unfortunately, but I think it's really worth:
Do you want to handle 1., @shirinyamani? In the meantime I'll merge this one.
thanks @shirinyamani!
@qgallouedec

Goal: is to have a param like `ref_model_sync_epochs`.

How to build:

Option 1): to override the current `SyncRefModelCallback`:

```python
# Imports this snippet relies on (inferred from the usage below)
from typing import Optional, Union

import deepspeed
import torch
from accelerate import Accelerator
from accelerate.state import AcceleratorState
from transformers import PreTrainedModel, TrainerCallback


class SyncRefModelCallback(TrainerCallback):
    def __init__(
        self,
        ref_model: Union[PreTrainedModel, torch.nn.Module],
        accelerator: Optional[Accelerator],
    ):
        self.accelerator = accelerator
        self.ref_model = ref_model

    @staticmethod
    def _sync_target_model(model, target_model, alpha):
        # Soft update: target <- (1 - alpha) * target + alpha * model
        for target_param, copy_param in zip(target_model.parameters(), model.parameters()):
            target_param.data.mul_(1.0 - alpha).add_(copy_param.data, alpha=alpha)

    @staticmethod
    def sync_target_model(model, target_model, alpha):
        deepspeed_plugin = AcceleratorState().deepspeed_plugin
        if deepspeed_plugin is not None and deepspeed_plugin.zero_stage == 3:
            # Under ZeRO-3, gather the sharded parameters before copying
            with deepspeed.zero.GatheredParameters(
                list(model.parameters()) + list(target_model.parameters()), modifier_rank=0
            ):
                if deepspeed.comm.get_rank() == 0:
                    SyncRefModelCallback._sync_target_model(model, target_model, alpha)
        else:
            SyncRefModelCallback._sync_target_model(model, target_model, alpha)

    def on_step_end(self, args, state, control, **kwargs):
        model: PreTrainedModel = kwargs["model"]
        if self.ref_model is not None and state.global_step % args.ref_model_sync_steps == 0:
            if self.accelerator:
                model = self.accelerator.unwrap_model(model)
            self.sync_target_model(model, self.ref_model, args.ref_model_mixup_alpha)
```

With the changes it would be sth like:

```python
def on_step_end(self, args, state, control, **kwargs):
    model: PreTrainedModel = kwargs["model"]
    # Calculate total steps per epoch
    steps_per_epoch = state.max_steps // args.num_train_epochs
    # Determine if we should sync based on ref_model_sync_epochs
    if isinstance(self.ref_model_sync_epochs, int):
        # Sync based on an integer number of epochs
        should_sync = state.global_step % (self.ref_model_sync_epochs * steps_per_epoch) == 0
    elif isinstance(self.ref_model_sync_epochs, float):
        # Sync based on a fraction of the total epochs
        should_sync = (state.global_step / steps_per_epoch) % (self.ref_model_sync_epochs * args.num_train_epochs) == 0
    else:
        raise ValueError("ref_model_sync_epochs must be an int or a float")

    if self.ref_model is not None and should_sync:
        if self.accelerator:
            model = self.accelerator.unwrap_model(model)
        self.sync_target_model(model, self.ref_model, args.ref_model_mixup_alpha)
```

This might work if my understanding of … is correct.

Option 2): is to add the …

Thoughts? 💭
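A minimal sketch of how Option 1 could be wired up, assuming the existing `SyncRefModelCallback` in `trl.trainer.utils` and a hypothetical `ref_model_sync_epochs` knob (the subclass name and the epoch-based argument are illustrations, not merged API):

```python
from trl.trainer.utils import SyncRefModelCallback  # existing TRL callback


class EpochSyncRefModelCallback(SyncRefModelCallback):
    """Hypothetical variant that syncs the ref model at epoch boundaries."""

    def __init__(self, ref_model, accelerator, ref_model_sync_epochs=1):
        super().__init__(ref_model=ref_model, accelerator=accelerator)
        self.ref_model_sync_epochs = ref_model_sync_epochs  # assumed new knob

    def on_epoch_end(self, args, state, control, **kwargs):
        # state.epoch is a float; sync only every `ref_model_sync_epochs` full epochs.
        # Assumes `args` carries ref_model_mixup_alpha (as DPOConfig does).
        if self.ref_model is None or int(state.epoch) % self.ref_model_sync_epochs != 0:
            return
        model = kwargs["model"]
        if self.accelerator:
            model = self.accelerator.unwrap_model(model)
        self.sync_target_model(model, self.ref_model, args.ref_model_mixup_alpha)


# Usage (assuming a GRPOTrainer instance named `trainer` with a `ref_model`):
# trainer.add_callback(
#     EpochSyncRefModelCallback(trainer.ref_model, trainer.accelerator, ref_model_sync_epochs=1)
# )
```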
One more question for you, @qgallouedec: if the ref_model is getting updated so frequently, would it be the same as not having a ref_model at all? 🤔
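A toy illustration (not TRL code) of why that matters: if the reference model is synced to the policy at nearly every step, the two distributions coincide and the per-token KL penalty collapses to ~0, so the regularization toward the reference effectively disappears.

```python
import torch

# Toy check: when the reference model has just been synced to the policy,
# their token distributions are identical, so KL(policy || ref) == 0.
policy_logits = torch.randn(4, 32)      # 4 positions, 32-token vocab
ref_logits = policy_logits.clone()      # "ref just synced to the policy"

logp = torch.log_softmax(policy_logits, dim=-1)
logq = torch.log_softmax(ref_logits, dim=-1)
kl = (logp.exp() * (logp - logq)).sum(dim=-1)  # per-position KL divergence

print(kl)  # tensor of zeros -> no effective KL regularization
```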
What does this PR do?
Following the thread of issue #2684 and based on the DeepSeek paper, we concluded that we need to add a feature which, every `ref_model_sync_steps` steps, can iteratively update the reference model.
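A hedged sketch of how this behaviour might be enabled from the config side, assuming the PR exposes the same knobs DPO already has (`sync_ref_model`, `ref_model_sync_steps`, `ref_model_mixup_alpha`); the exact names in the merged version may differ:

```python
from trl import GRPOConfig

# Assumed config fields, mirroring DPOConfig's ref-model sync options
config = GRPOConfig(
    output_dir="grpo-iterative",
    sync_ref_model=True,        # enable iterative updates of the reference model
    ref_model_sync_steps=64,    # sync every 64 optimizer steps
    ref_model_mixup_alpha=0.6,  # ref <- 0.6 * policy + 0.4 * ref at each sync
)
```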
Who can review?
@qgallouedec
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.